"""SVM-map implementation of python module for SVM^python"""

import re, sys, itertools, svmlight, pdb, time, array

#####################################################################
# INPUT FORMATTING
# Assumes input file is an index file with names of data files
# Each example in a data file assumed to be of the format:
#      [label] qid:# feature_id:# feature_id:# ... feature_id:#
#
# For example:
#      1 qid:1 3:5 7:-3 12:4 22:5 159:9
#
# You can also check out the sample data file.
#
# For more information, please read the README file.


#####################################################################
# NOTATION FOR THIS CODE
#
# To learn about the algorithm, please refer to Algorithm 2 in [1]
#
# This section covers notation and definitions used in this code to
# aid in mapping the algorithm description from that paper to this
# code file.
#
################################
#
# Input examples (denoted by 'x' in the code) each contain all the
# examples for a given query, and are stored as
#
#           x = (qid, feature_id_list, feature_value_list)
#
# where   qid is an integer denoting the query id
#         feature_id_list is an array of feature_ids objects
#         feature_value_list is an array of feature_values objects
#
# and     feature_ids is an array of feature (dimension) ids
#         feature_values is an array of corresponding feature values
#
################################
#
# Target labels (denoted by 'y' in the code) contain the relevance
# judgments for each example for a given query, and are stored as
#
#           y = (labels, type)
#
# where  labels is a list of target values corresponding to each
#           document (aka example) in x
#        type is an integer which indicates the type of label
#
# type values:
#        0 - true target labels (specified by dataset).  
#
#            If there are P relevant and N non-relevant documents in 
#            the dataset, then all relevant documents will store the 
#            label 1/P, and non-relevant documents will store the label 
#            -1/N.
#
#        2 - target labels generated by find_most_violated_constraint.
#
#            The find_most_violated_constraint function finds a 
#            total ordering of the documents.  We store this ordering
#            by considering the relative ordering of each relevant 
#            document against each non-relevant document.  For example, 
#            for some relevant document d+, let NA indicate the number
#            of non-relevant documents ordered after it and NB be the 
#            number of non-relevant documents ordered before it.  Then the
#            target score for d+ is (-NB + NA)/(P*N).  Note that if all
#            non-relevant documents are ranked after d+, then NB = 0 and
#            NA = N, so (-NB + NA)/(P*N) = 1/P.
#
#        1 - the target scores generated by the classification function
#            (classify_struct_example).
#
#            SVM-Map classifies documents by computing a score for each
#            document in the set.  The ranking is the one induced by
#            sorting on the score in descending order (you'll have to
#            sort the documents post-processing).


#####################################################################
# REFERENCES
#
# [1] "A Support Vector Machine for Optimizing Average Precision",
# by Y. Yue, T. Finley, F. Radlinski, and T. Joachims,
# In the Proceedings of SIGIR, 2007.  

# Module-level options dictionary exposed to the SVM^python host.
# NOTE(review): 'index_from_one' presumably tells the host whether feature
# indices into sm.w start at 1 instead of 0 -- confirm against the
# SVM^python documentation before changing it.
svmpython_parameters={'index_from_one':False}

def parse_struct_parameters(sparm):
    """Give the module a chance to adjust sparm from the command line.

    The custom command line arguments are available as the list of
    strings sparm.argv and, pre-parsed, as the dictionary sparm.argd
    (for instance '--key1 value1 --key2 value2' becomes
    {'key1':'value1', 'key2':'value2'}).  The function is called only
    during learning, not classification, and returns nothing.

    This implementation does not look at the arguments at all; it only
    stores a demonstration string in sparm.arbitrary_parameter.

    If this function is not implemented, any custom command line
    arguments (aside from --m, of course) are ignored and sparm remains
    unchanged."""
    demo_value='I am an arbitrary parameter! (THANK YOU TOM)'
    setattr(sparm, 'arbitrary_parameter', demo_value)

def read_struct_examples(filename, sparm):
    """Read the per-query (x, y) example pairs stored in filename.

    Two layouts are understood.  The layout is selected by an optional
    marker line that starts with '#type' immediately followed by a digit:

      * no marker, or '#type0': filename itself holds the data.  Each
        line is 'label query_id v1 v2 ... vn'.  The values are
        positional: the k-th value belongs to feature id k (ids start at
        1) and zero values are not stored.  Consecutive lines sharing a
        query id form one example.  Blank lines and lines starting with
        '#' are skipped.
      * '#type1': filename is an index file naming one data file per
        line ('#' lines and blank lines are skipped).  Every data file
        holds exactly one query, one document per line, in the form
        'label qid:<id> <feature_id>:<value> ...'; '#' starts a comment.

    A document is relevant when its label is > 0.  For a query with P
    relevant and N non-relevant documents the stored labels are 1/P for
    relevant and -1/N for non-relevant documents.

    Arguments:
      filename -- path of the data file or of the index file
      sparm    -- unused here; part of the SVM^python interface

    Returns a list of examples ((qid, feature_id_list,
    feature_value_list), (labels, 0)).  If the '#type' marker names an
    unknown layout, "ERROR FILE TYPE" is printed and 0 is returned (kept
    for backward compatibility).

    Raises ZeroDivisionError (after printing "zeros error: N P") when a
    query has no relevant or no non-relevant document, and ValueError
    when a token cannot be parsed."""

    def emit(text):
        # A single write produces the same output on Python 2 and 3.
        sys.stdout.write(text+'\n')

    def clamp_label(token):
        # Squash raw judgments so that only "relevant or not" survives.
        label=float(token)
        if label<=0:
            return -1
        if label>1:
            return 1
        return label

    def build_example(query_id, feature_id_list, feature_value_list, relevance_list):
        # Turn the documents collected for one query into an (x, y) pair
        # with the normalised 1/P and -1/N target labels.
        number_of_negs=sum(1 for v in relevance_list if v<=0)
        number_of_pos=sum(1 for v in relevance_list if v>0)
        if number_of_negs==0 or number_of_pos==0:
            emit("zeros error: %s %s"%(number_of_negs, number_of_pos))
            raise ZeroDivisionError(
                "query %r has %d relevant and %d non-relevant documents; "
                "at least one of each is required"
                %(query_id, number_of_pos, number_of_negs))
        psi_norm=1.0/float(number_of_pos*number_of_negs)
        neg_label=float(-number_of_pos)*psi_norm
        pos_label=float(number_of_negs)*psi_norm
        labels=[]
        for v in relevance_list:
            if v<=0:
                labels.append(neg_label)
            else:
                labels.append(pos_label)
        x=(query_id, feature_id_list, feature_value_list) # input examples
        y=(labels, 0) # target labels -- see header comments for details
        return (x, y)

    def read_struct_examples_byfulldata(filename):
        examples=[]
        current_qid=None
        feature_id_list, feature_value_list, relevance_list=[], [], []
        with open(filename) as handle:
            for raw in handle:
                line=raw.strip().split()
                # Skip blank lines and comment/marker lines such as '#type0'.
                if not line or line[0].startswith('#'):
                    continue
                label=clamp_label(line[0])
                query_id=int(float(line[1]))
                if current_qid is not None and query_id!=current_qid:
                    # A new query starts: store the finished one under its
                    # own id before collecting the next group.
                    examples.append(build_example(current_qid, feature_id_list,
                                                  feature_value_list, relevance_list))
                    feature_id_list, feature_value_list, relevance_list=[], [], []
                current_qid=query_id

                feature_ids=array.array('I')  #list of feature id's for a single document
                feature_values=array.array('f') #corresponding list of feature values
                for idx, itm in enumerate(line[2:]):
                    value=float(itm)
                    if value!=0:
                        feature_ids.append(idx+1)
                        feature_values.append(value)
                feature_id_list.append(feature_ids)
                feature_value_list.append(feature_values)
                relevance_list.append(label)
        # The last query has no following line to trigger its storage.
        if relevance_list:
            examples.append(build_example(current_qid, feature_id_list,
                                          feature_value_list, relevance_list))
        return examples

    def read_struct_examples_byindex(filename):
        def line_reader(lines):
            # Given lines, yield only non-empty lines with comments stripped.
            for l in lines:
                i=l.find('#')
                if i!=-1:
                    l=l[:i]
                l=l.strip()
                if l:
                    yield l

        with open(filename) as index_handle:
            data_paths=[entry.strip() for entry in index_handle]

        examples=[]
        for data_path in data_paths:
            if not data_path or data_path.startswith('#'):
                continue
            query_id=None
            feature_id_list, feature_value_list, relevance_list=[], [], []
            with open(data_path) as data_handle:
                for text in line_reader(data_handle):
                    line=text.split()
                    if len(line)<2 or not line[1].startswith('qid:'):
                        raise ValueError("expected 'label qid:<id> ...' in %s, got %r"
                                         %(data_path, text))
                    feature_ids=array.array('I')  #list of feature id's for a single document
                    feature_values=array.array('f') #corresponding list of feature values
                    for token in line[2:]:
                        parts=token.split(':')
                        feature_ids.append(int(parts[0]))
                        feature_values.append(float(parts[1]))
                    query_id=int(line[1][4:])
                    feature_id_list.append(feature_ids)
                    feature_value_list.append(feature_values)
                    relevance_list.append(clamp_label(line[0]))
            examples.append(build_example(query_id, feature_id_list,
                                          feature_value_list, relevance_list))
        return examples

    #fileType: 0 or empty->the data file(default) 1->the index file
    with open(filename) as handle:
        filetypeline=[l.strip() for l in handle if l.startswith('#type')]
    if not filetypeline:
        return read_struct_examples_byfulldata(filename)
    # The layout digit directly follows '#type'; slicing tolerates a bare marker.
    filetype=filetypeline[0][5:6]
    if filetype=='0':
        return read_struct_examples_byfulldata(filename)
    if filetype=='1':
        return read_struct_examples_byindex(filename)
    emit("ERROR FILE TYPE")
    return 0

def init_struct_model(sample, sm, sparm):
    """Initialize the structure model sm before learning.

    Sets sm.size_psi to one more than the largest feature id that
    occurs in any document of any example in sample, and prints the
    chosen size.  Documents without features are ignored; when the
    sample is empty or no document has a feature, size_psi is 1
    instead of failing on an empty max().

    Arguments:
      sample -- sequence of ((qid, feature_id_list, feature_value_list), y)
      sm     -- structure model; receives the size_psi attribute
      sparm  -- unused here; part of the SVM^python interface

    Returns nothing."""
    largest_id=0
    for (qid, feature_id_list, feature_value_list), labelling in sample:
        for feature_ids in feature_id_list:
            if len(feature_ids)>0:
                largest_id=max(largest_id, max(feature_ids))
    sm.size_psi=largest_id+1
    # A single write produces the same output on Python 2 and 3.
    sys.stdout.write('Size of psi is %s\n'%(sm.size_psi,))

def init_struct_constraints(sample, sm, sparm):
    """Provide special initial constraints -- none for SVM-map.

    The SVM^python interface lets this hook return a sequence of
    initial constraints, each a two-item sequence (intended to be a
    tuple).  The first item is a document object with at least its fvec
    attribute set to a support vector object or a list of them; the
    second item is a number that the inner product of that feature
    vector with the linear weights must reach or exceed (or, in the
    nonlinear case, that the kernel evaluation against the current
    model must exceed).  Typically no special constraints are necessary.

    The docnum attribute of returned documents is ignored.  Slack IDs 0
    through len(sample)-1 are reserved for the training examples; a
    document whose slackid is left as None (the default for
    svmlight.create_doc) gets slackid=len(sample)+i, where i is its
    position within the returned list.

    This implementation adds nothing and returns None.  Per the
    interface description, not implementing the function is equivalent
    to returning an empty list, i.e., no constraints."""
    return None

def classify_struct_example(x, sm, sparm):
    """Score every document of query x with the linear model sm.w.

    Arguments:
      x     -- (qid, feature_id_list, feature_value_list); the i-th
               entries of the two lists describe document i
      sm    -- structure model; only sm.w (indexable weights) is read
      sparm -- unused here; part of the SVM^python interface

    Returns (scores, 1).  scores[i] is the sum of sm.w[k]*v over the
    features (k, v) of document i; feature ids that lie beyond the end
    of sm.w are ignored.  The 1 signifies the type of y (the prediction
    kind); see the header comments for the y type definitions."""
    (qid, feature_id_list, feature_value_list)=x

    # itertools.izip only exists on Python 2; the builtin zip is the lazy
    # equivalent on Python 3.
    pairs=getattr(itertools, 'izip', zip)

    weights=sm.w
    len_of_model=len(weights)
    scores=[sum(weights[k]*v
                for k, v in pairs(feature_ids, feature_values) if k<len_of_model)
            for (feature_ids, feature_values) in pairs(feature_id_list, feature_value_list)]

    return scores, 1

def find_most_violated_constraint(x, y, sm, sparm):
    """Return the labeling ybar of x's most violated constraint.

    Every document of the query is scored with classify_struct_example.
    Starting from the ordering that places all relevant documents ahead
    of all non-relevant ones, the non-relevant documents are moved
    upward one at a time for as long as the accumulated change in
    loss plus discriminant is positive (the swapping algorithm of [1],
    see the header comments).

    Arguments:
      x     -- (qid, feature_id_list, feature_value_list) for one query
      y     -- (labels, type); labels > 0 mark relevant documents and
               labels <= 0 mark non-relevant ones
      sm    -- structure model, passed on to classify_struct_example
      sparm -- passed on to classify_struct_example

    Returns (clist, 2).  clist[i] is the target value of document i
    (indexed like the labels in y) under the ordering found; see the
    header comments for the y type definitions.

    Raises ZeroDivisionError when the query has no relevant or no
    non-relevant document, because psi_norm is 1/(P*N).

    NOTE(review): sparm.loss_type and sparm.slack_norm are not consulted
    here, and None is never returned -- confirm that this is intended.

    If this function is not implemented, SVM^python behaves as if it
    were 'classify(x, sm, sparm)'.  The guarantees of optimality of
    Tsochantaridis et al. then no longer hold, since the loss is not
    taken into account at all."""

    print "Finding Max Constraint..."

    # Build a list of tuples of true c_i, predicted w^T * x_i, and i.
    # NOTE(review): cost_objective is assigned but never read.
    cost_objective=0.0
    y2, y_type=y;
    temp_score=[ss for ss in classify_struct_example(x, sm, sparm)[0]]

    # Documents are initially sorted perfectly, with relevant
    # documents ahead of non-relevant documents, and secondarily by
    # the classification score in descending order (remaining ties
    # fall back to the original index, also descending).
    orders=sorted(zip(y2, temp_score, range(len(y2))))[::-1]

    # P = num_positives, N = num_negatives; psi_norm = 1/(P*N).
    num_negatives=sum(1 for a, b, c in orders if a<=0)
    num_total=len(orders)
    num_positives=num_total-num_negatives
    psi_norm=1.0/float(num_positives*num_negatives)

    # swapping algorithm for finding the most violated constraint
    # refer to [1] (see header comments) for description

    # begin by looping through each non-relevant document; in the
    # initial ordering they occupy positions num_positives .. end
    for j in xrange(num_positives, len(orders)):
        this_is_zero, wTx_neg, original_index=orders[j]

        # Default is to swap with itself (i.e. leave document j in place).
        # max_delta is the best accumulated gain seen so far, delta_cost
        # the running accumulated gain.
        max_delta, delta_cost, to_swap_index=0.0, 0.0, j
        #max_loss, max_psi = 0.0, 0.0
        delta_loss, delta_psi=0.0, 0.0

        # looping through all documents currently ranked before document j,
        # from the nearest one upward
        for i in reversed(xrange(j)):

            # we consider the benefit of swapping documents i and j
            to_swap_label, wTx_pos, to_swap_original_index=orders[i]

            # we never change the relative ordering of the non-relevant
            # documents.  So we quit this loop if document i is
            # also non-relevant
            if to_swap_label<=0: break

            # jn: 1-based count of non-relevant documents up to and
            # including j; ii: 1-based position of document i.
            jn=float(j-num_positives+1)
            ii=i+1

            # change in the loss function
            delta_loss=1.0/num_positives*(jn/(ii+1)-(jn-1)/(ii))

            # change in the discriminant function
            delta_psi=-2.0*(wTx_pos-wTx_neg)*psi_norm

            # change in the objective function, accumulated over all the
            # relevant documents that document j has been moved past
            delta_delta_cost=delta_loss+delta_psi
            delta_cost+=delta_delta_cost

            if delta_cost>max_delta:
                # This is the current best swap position.
                max_delta, to_swap_index=delta_cost, i
                #max_loss, max_psi = delta_loss, delta_psi

        # if we should swap, then swap
        # if we shouldn't swap, then no non-relevant documents
        # ranked later will swap, so we quit
        if j==to_swap_index: break
        # Move document j up to to_swap_index, shifting the documents
        # in between down by one position.
        removed_boy=orders[j]
        orders[to_swap_index+1:j+1]=orders[to_swap_index:j]
        orders[to_swap_index]=removed_boy


    # Now we have the swapped order.  Compute the target labels.
    clist=[0]*len(y2)
    num_zeros, num_ones=0, 0

    # Walking from the bottom of the ranking: num_zeros counts the
    # non-relevant documents ranked after each relevant document, so its
    # target is (after - before)/(P*N) = (2*num_zeros - N)*psi_norm.
    for true_label, pred, original_index in reversed(orders):
        if true_label<=0: num_zeros+=1
        else: clist[original_index]=(2*num_zeros-num_negatives)*psi_norm
    # Walking from the top of the ranking: num_ones counts the relevant
    # documents ranked before each non-relevant document, so its target
    # is (P - 2*num_ones)*psi_norm.
    for true_label, pred, original_index in orders:
        if true_label>0: num_ones+=1
        else: clist[original_index]=(num_positives-2*num_ones)*psi_norm

    return clist, 2
    # the 2 at the end signifies the type of y (the int kind)
    # see header comments for y type definitions

def psi(x, y, sm, sparm):
    """Return the joint feature vector of pattern x and labeling y.

    Each document's feature vector is scaled by that document's entry
    in the labeling and the scaled vectors are summed feature by
    feature.  The sum is handed to svmlight.create_svector as a list of
    (feature_id, value) pairs sorted by feature id, and the support
    vector object it returns is the result.

    The type component of y is ignored, as are sm and sparm."""
    doc_weights, _unused_kind=y
    _unused_qid, id_rows, value_rows=x

    # feature id -> accumulated weighted value over all documents
    accumulated={}
    for doc_index, doc_weight in enumerate(doc_weights):
        doc_ids=id_rows[doc_index]
        doc_values=value_rows[doc_index]
        for feature_id, feature_value in itertools.izip(doc_ids, doc_values):
            running=accumulated.get(feature_id, 0.0)
            accumulated[feature_id]=running+doc_weight*feature_value

    sorted_pairs=list(accumulated.items())
    sorted_pairs.sort()
    return svmlight.create_svector(sorted_pairs)

def loss(y, ybar, sparm):
    """Return 1 - average precision of ybar relative to the true labeling y.

    y is (labels, type) where labels > 0 mark the relevant documents.
    How ybar = (values, type) is read depends on its type (see the
    header comments for the type definitions):

      * type 1 with a non-empty value list: the values are per-document
        scores.  Documents are ranked by descending score; on equal
        scores the document with the smaller true label comes first.
        If no document is relevant the loss is 1.
      * anything else: the values are c-style targets.  For every
        relevant document the number of non-relevant documents ranked
        above it is recovered as (N - c*P*N)/2 and the average
        precision is computed from those counts.

    sparm is not used."""
    true_labels, _true_kind=y
    predicted, predicted_kind=ybar

    if predicted and predicted_kind==1:
        # Score-based prediction: walk the induced ranking from the top.
        negated_truth=[-label for label in true_labels]
        ranking=sorted(zip(predicted, negated_truth, true_labels), reverse=True)
        hits=0
        precision_sum=0.0
        for rank, (_score, _tiebreak, truth) in enumerate(ranking, 1):
            if truth>0:
                hits+=1
                precision_sum+=float(hits)/rank
        if hits==0:
            return 1
        return 1.0-precision_sum/hits

    # Otherwise it's the sort of integral C representation.
    document_count=float(len(true_labels))
    negative_count=sum(1 for label in true_labels if label<=0)
    positive_count=document_count-negative_count
    inverse_norm=float(positive_count*negative_count)

    # only the relevant documents matter, best-placed first
    relevant=sorted(((truth, target)
                     for truth, target in zip(true_labels, predicted) if truth>0),
                    reverse=True)

    precision_sum=0.0
    last_index=0
    for last_index, (_truth, target) in enumerate(relevant):
        # number of non-relevant documents ranked above this relevant one
        negatives_above=(negative_count-target*inverse_norm)*0.5
        precision_sum+=float(last_index+1)/(last_index+1+negatives_above)

    return 1.0-precision_sum/(last_index+1)

def print_struct_learning_stats(sample, sm, cset, alpha, sparm):
    """Print statistics once learning has finished.
    
    This is called after training primarily to compute and print any
    statistics regarding the learning (e.g., training error) of the
    model on the training sample.  You may also use it to make final
    changes to sm before it is written out to a file.  For example, if
    you defined any non-pickle-able attributes in sm, this is a good
    time to turn them into a pickle-able object before it is written
    out.  Also passed in is the set of constraints cset as a sequence
    of (left-hand-side, right-hand-side) two-element tuples, and an
    alpha of the same length holding the Lagrange multipliers for each
    constraint.

    If this function is not implemented, the default behavior is
    equivalent to:
    'print [loss(e[1], classify(e.[0], sm, sparm)) for e in sample]'."""
    training_errors=[
        loss(e[1], classify_struct_example(e[0], sm, sparm), sparm)
        for e in sample]
    print 'training errors are',
    print training_errors

def print_struct_testing_stats(sample, sm, sparm, teststats):
    """Print the per-example test losses once classification has finished.

    This is called after all test predictions are made.  teststats is
    the value accumulated through eval_prediction, i.e. a list with one
    loss per test example; every loss is printed with four decimals.
    teststats is still None when no example was evaluated, in which
    case an empty list is printed instead of failing.

    sample, sm and sparm are not used.  Returns nothing.

    If this function is not implemented, the default behavior is
    equivalent to 'print teststats'."""
    if teststats is None:
        teststats=[]
    formatted=['%.4f'%(itm) for itm in teststats]
    # A single write produces the same output on Python 2 and 3.
    sys.stdout.write("teststats= %s\n"%(formatted,))

def eval_prediction(exnum, x, y, ypred, sm, sparm, teststats):
    """Accumulate statistics about a single training example.    
    Allows accumulated statistics regarding how well the predicted
    label ypred for pattern x matches the true label y.  The first
    time this function is called teststats is None.  This function's
    return value will be passed along to the next call to
    eval_prediction.  After all test predictions are made, the last
    value returned will be passed along to print_testing_stats.

    If this function is not implemented, the default behavior is
    equivalent to initialize teststats as an empty list on the first
    example, and thence for each prediction appending the loss between
    y and ypred to teststats, and returning teststats."""
    if exnum==0: teststats=[]
    lv=loss(y, ypred, sparm)
    #print 'on example',exnum,'predicted',ypred,'where correct is',y    
    ypred_list, ypred_type=ypred
    y_list, y_type=y
    rank_pred=sorted(range(len(ypred_list)), key=ypred_list.__getitem__, reverse=True)
    rank_ori=sorted(range(len(y_list)), key=y_list.__getitem__, reverse=True)
    #print 'Example','%04d'%(exnum),'loss=','%.4f'%(float(lv)), 'predicted',ypred,'where correct is',y
    print 'Example', '%04d'%(exnum), 'loss=', '%.4f'%(float(lv)), 'predicted', rank_pred, 'where correct is', rank_ori
    teststats.append(lv)
    return teststats

def write_struct_model(filename, sm, sparm):
    """Dump the structmodel sm to disk.

    Two files are written:
      * '<filename>.w' -- repr(sm.w), a human readable copy of the weights
      * filename       -- the pickled structmodel

    The pickle is written in binary mode, which pickle requires on
    Python 3 and which avoids newline translation; on POSIX systems the
    bytes are identical to what text mode produced.  Both files are
    closed even when writing fails.

    sparm is not used.  Returns nothing.

    If this function is not implemented, the default behavior is
    equivalent to 'pickle.dump(sm, file(filename,'w'))'."""
    import pickle
    with open('%s.w'%(filename,), 'w') as weights_file:
        weights_file.write(repr(sm.w))
    with open(filename, 'wb') as model_file:
        pickle.dump(sm, model_file)

def read_struct_model(filename, sparm):
    """Load the structure model from a file.

    Unpickles the structmodel stored at path filename, prints its
    weights (model.w) with four decimals, and returns the model.  The
    file is opened in binary mode, as pickle requires on Python 3, and
    is closed before returning.  Errors while opening or unpickling
    propagate to the caller; None is never returned.

    sparm is not used.

    NOTE: unpickling executes code chosen by whoever wrote the file, so
    only load model files from a trusted source.

    If this function is not implemented, the default behavior is
    equivalent to 'return pickle.load(file(filename))'."""
    import pickle
    with open(filename, 'rb') as model_file:
        model=pickle.load(model_file)
    formatted=['%.4f'%(float(wi)) for wi in model.w]
    # A single write produces the same output on Python 2 and 3.
    sys.stdout.write("\nW= %s\n"%(formatted,))
    return model

def write_label(fileptr, y):
    """Write a predicted label to an open file.

    Called during classification.  Writes repr() of the value list of
    y = (values, type), followed by a newline, so that every label
    occupies its own line.  The type component is not written.

    fileptr is an already open file object, not a filename, and it is
    not closed by this function.

    If this function is not implemented, the default behavior is
    equivalent to 'fileptr.write(repr(y)+'\\n')'."""
    y2, y_type=y
    # A real newline: the escaped form "\\n" would emit a backslash and
    # an 'n' and glue all labels onto a single line.
    fileptr.write(repr(y2)+"\n")

def write_label_byname(filename, y):
    """Append a predicted label to the file at path filename.

    Writes repr() of the value list of y = (values, type) followed by a
    newline.  The file is opened in append mode (and created if it does
    not exist) and is closed again even if the write fails."""
    y2, y_type=y
    with open(filename, "a") as label_file:
        label_file.write(repr(y2)+'\n')

def print_struct_help():
    """Help printed for badly formed CL-arguments when learning.

    If this function is not implemented, the program prints the
    default SVM^struct help string as well as a note about the use of
    the --m option to load a Python module."""
    print """Help!  I need somebody.  Help!  Not just anybody.
    Help!  You know, I need someone.  Help!"""

